import torch.nn as nn
import torch
from torchvision import transforms
import math
from matplotlib import image
from matplotlib import pyplot as plt
import numpy as np
import torch.onnx

class VGG(nn.Module):
    """VGG-style image classification network.

    The model is split into two sequential parts:

    * ``features`` -- the convolution/pooling stack supplied by the caller
      (normally built with ``make_layers``), applied first in ``forward``.
    * ``classifier`` -- three fully connected layers mapping the flattened
      ``512 * 7 * 7`` feature map to ``num_classes`` output logits.

    Args:
        features: module applied to the input batch; its output must flatten
            to ``512 * 7 * 7`` values per sample to fit the first Linear layer.
        num_classes: number of output logits (size of the last Linear layer).
        init_weights: if True, re-initialise parameters via
            ``_initialize_weights`` after construction.
    """

    def __init__(self, features, num_classes=1000, init_weights=True):
        super(VGG, self).__init__()  # equivalent to super().__init__() in Python 3
        # Convolutional part; invoked first in forward().
        self.features = features
        # nn.Linear(in_features, out_features, bias=True) works on 2-D tensors
        # of shape [batch_size, size]: in_features is the size of each input
        # row, out_features the size of each output row (= number of neurons).
        self.classifier = nn.Sequential(
            # [batch_size, 512*7*7] -> [batch_size, 4096]
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),

            # [batch_size, 4096] -> [batch_size, 4096]
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            # Final layer emits raw logits, so no activation follows it.
            nn.Linear(4096, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Run the network on a batch and return class logits.

        Args:
            x: input tensor whose dim 0 is the batch dimension.

        Returns:
            Tensor of shape ``[batch_size, num_classes]``.
        """
        x = self.features(x)
        # Keep the batch dimension and flatten the rest for the FC layers.
        # torch.flatten (unlike .view) also copes with non-contiguous tensors
        # and exports as a single ONNX Flatten node.
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        """Initialise the parameters of every submodule in place.

        * Conv2d: weight ~ N(0, std=sqrt(2 / n)) with
          n = kernel_h * kernel_w * out_channels (Kaiming normal, fan-out);
          bias (if present) set to 0.
        * BatchNorm2d: weight = 1, bias = 0 (when the layer is affine).
        * Linear: weight ~ N(0, std=0.01); bias (if present) set to 0.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                # affine=False BatchNorm layers have no weight/bias tensors.
                if m.weight is not None:
                    m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                if m.bias is not None:
                    m.bias.data.zero_()

def make_layers(cfg, batch_norm=False):
    """Build the convolutional feature extractor described by ``cfg``.

    Args:
        cfg: sequence of layer specs. The string ``'M'`` adds a 2x2 max-pool
            with stride 2 (halving height and width); an integer adds a 3x3
            convolution (padding 1) with that many output channels, followed
            by an optional BatchNorm2d and an in-place ReLU.
        batch_norm: insert a BatchNorm2d between each convolution and its ReLU.

    Returns:
        ``nn.Sequential`` containing the layers in ``cfg`` order.
    """
    modules = []
    channels = 3  # the first convolution consumes 3-channel (RGB) input
    for spec in cfg:
        if spec == 'M':
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        # (in_channels, out_channels, kernel_size, padding)
        modules.append(nn.Conv2d(channels, spec, kernel_size=3, padding=1))
        if batch_norm:
            modules.append(nn.BatchNorm2d(spec))
        modules.append(nn.ReLU(inplace=True))
        channels = spec  # next convolution reads this layer's output
    return nn.Sequential(*modules)

# Layer layouts for the VGG variants, consumed by make_layers():
# an integer is a 3x3 convolution with that many output channels, 'M' is a
# 2x2 max-pool. A -> vgg11, B -> vgg13, D -> vgg16, E -> vgg19.
cfg = dict(
    A=[64, 'M',
       128, 'M',
       256, 256, 'M',
       512, 512, 'M',
       512, 512, 'M'],
    B=[64, 64, 'M',
       128, 128, 'M',
       256, 256, 'M',
       512, 512, 'M',
       512, 512, 'M'],
    D=[64, 64, 'M',
       128, 128, 'M',
       256, 256, 256, 'M',
       512, 512, 512, 'M',
       512, 512, 512, 'M'],
    E=[64, 64, 'M',
       128, 128, 'M',
       256, 256, 256, 256, 'M',
       512, 512, 512, 512, 'M',
       512, 512, 512, 512, 'M'],
)


def vgg11(**kwargs):
    """Build a VGG11 model (layer configuration 'A', no batch norm).

    The convolutional stack from ``make_layers(cfg['A'])`` is passed to
    ``VGG`` as ``features``; all keyword arguments (e.g. ``num_classes``,
    ``init_weights``) are forwarded to ``VGG`` unchanged.

    Returns:
        The constructed ``VGG`` module.
    """
    features = make_layers(cfg['A'])
    return VGG(features, **kwargs)


def vgg11_bn(**kwargs):
    """Build VGG11 with batch normalisation (configuration 'A').

    Keyword arguments are forwarded to ``VGG`` unchanged; returns the model.
    """
    features = make_layers(cfg['A'], batch_norm=True)
    return VGG(features, **kwargs)


def vgg13(**kwargs):
    """Build VGG13 (configuration 'B', no batch norm).

    Keyword arguments are forwarded to ``VGG`` unchanged; returns the model.
    """
    features = make_layers(cfg['B'])
    return VGG(features, **kwargs)


def vgg13_bn(**kwargs):
    """Build VGG13 with batch normalisation (configuration 'B').

    Keyword arguments are forwarded to ``VGG`` unchanged; returns the model.
    """
    features = make_layers(cfg['B'], batch_norm=True)
    return VGG(features, **kwargs)


def vgg16(**kwargs):
    """Build VGG16 (configuration 'D', no batch norm).

    Keyword arguments are forwarded to ``VGG`` unchanged; returns the model.
    """
    features = make_layers(cfg['D'])
    return VGG(features, **kwargs)


def vgg16_bn(**kwargs):
    """Build VGG16 with batch normalisation (configuration 'D').

    Keyword arguments are forwarded to ``VGG`` unchanged; returns the model.
    """
    features = make_layers(cfg['D'], batch_norm=True)
    return VGG(features, **kwargs)


def vgg19(**kwargs):
    """Build VGG19 (configuration 'E', no batch norm).

    Keyword arguments are forwarded to ``VGG`` unchanged; returns the model.
    """
    features = make_layers(cfg['E'])
    return VGG(features, **kwargs)


def vgg19_bn(**kwargs):
    """Build VGG19 with batch normalisation (configuration 'E').

    Keyword arguments are forwarded to ``VGG`` unchanged; returns the model.
    """
    features = make_layers(cfg['E'], batch_norm=True)
    return VGG(features, **kwargs)



if __name__ == '__main__':  # runs only when this file is executed directly
    # VGG expects 224x224 RGB images: five 2x2 max-pools shrink 224 -> 7,
    # matching the 512 * 7 * 7 input of the first Linear layer. (The previous
    # 244 only worked because pooling floors odd sizes, and it baked a
    # non-standard input size into the exported ONNX graph.)
    dummy_input = torch.rand(1, 3, 224, 224)  # avoid shadowing builtin input()
    model = vgg11()
    model.eval()  # disable Dropout so the sanity-check forward is deterministic
    with torch.no_grad():
        output = model(dummy_input)
    print(output.shape)  # sanity check: [1, num_classes]
    torch.onnx.export(model, dummy_input, "netForwatch.onnx")



